#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import ast
import collections
import distutils.spawn
import imp
import json
import logging
import os
import pipes
import re
import resource
import shutil
import subprocess
import sys
import tempfile
import textwrap

import common

# Module-level logger, named after this script's file name.
logger = logging.getLogger(os.path.basename(__file__))

# Mirrors the flag of the same name in the project's `common` module.
# NOTE(review): presumably true when running from an installed tree rather
# than a build tree -- confirm against common.
IN_PROD_TREE = common.IN_PROD_TREE

# Some platforms preface function symbols with an underscore, some do not.
# Only macOS ('darwin') is treated as an underscore platform here.
# TODO: Compute this in some more portable way.
symPrefix = '_' if sys.platform == 'darwin' else ''


def filterDisassembly(disasm, sources):
    """Filter out any extracted disassembly not reachable from ``sources``.

    Each entry of ``disasm`` is a ``(symbol, skFunctions, llvm)`` tuple as
    built by emitDisassembly: the LLVM symbol name, a tuple of Skip function
    descriptions (``(name, file, start, end)``, or just ``(name,)`` for
    machine-generated functions) and the LLVM text of that symbol.

    The search starts from every entry that has a Skip function defined in
    one of the files in ``sources`` and then follows the ``@symbol``
    references found in the LLVM text transitively.

    Returns the retained entries as a list, in the same order as ``disasm``.
    An empty or None ``sources`` yields an empty list.
    """
    sources = sources or ()

    # Index entries by symbol so resolving a reference is a dict lookup
    # rather than a scan over all of disasm.
    bySymbol = collections.defaultdict(list)
    for func in disasm:
        bySymbol[func[0]].append(func)

    def isRoot(func):
        # Machine-generated functions are described by a 1-tuple without a
        # source file, so they can never be a starting point.
        return any(len(f) > 1 and f[1] in sources for f in func[1])

    referencedFuncs = set(func for func in disasm if isRoot(func))
    # queue of entries whose LLVM text still has to be searched
    queue = collections.deque(referencedFuncs)

    # grabs llvm functions (prefixed with @)
    llvmFuncRE = re.compile(r'@([\w.]+)')

    while queue:
        # get references used in first function on the queue
        symbols = set(llvmFuncRE.findall(queue.popleft()[2]))
        for symbol in symbols:
            for func in bySymbol.get(symbol, ()):
                # Queue an entry the first time it is seen; the membership
                # test must happen before it is recorded as referenced,
                # otherwise references are never followed transitively.
                if func not in referencedFuncs:
                    referencedFuncs.add(func)
                    queue.append(func)

    return [func for func in disasm if func in referencedFuncs]


def cleanDisassembly(asm):
    """Strip location directives and unreferenced labels from assembly text.

    Lines starting with a tab followed by ``.loc`` or ``.file`` are dropped.
    A line that defines a label is dropped as well when the label's text
    occurs fewer than twice in ``asm``, i.e. nowhere besides its own
    definition.  Labels whose first character is an underscore are never
    considered for removal.

    The logic is duplicated from the playground's "Hide locations" and
    "Hide dead labels".
    """
    labelDefinition = re.compile(r'^([^_][^: \t]+):')
    directivePrefixes = ('\t.loc', '\t.file')

    kept = []
    for line in asm.split('\n'):
        if line.startswith(directivePrefixes):
            continue
        found = labelDefinition.search(line)
        # The definition itself accounts for one occurrence of the label.
        if found is not None and asm.count(found.group(1)) < 2:
            continue
        kept.append(line)

    return '\n'.join(kept)


def extractAsmFunctions(asm, names):
    """Given the text of a .s file, returns a dict containing the assembly
    code for every function in "names".

    "names" holds symbol names without the platform prefix (symPrefix is
    prepended when searching).  Each value runs from the symbol's label up
    to, but not including, the first .cfi_endproc or the .size directive
    for that symbol, with "@PLT" suffixes on call/jmp targets and all
    .cfi_* lines removed, and always ends in a newline.  Names that have
    no matching label in "asm" are absent from the result.
    """
    names = list(names)
    if not names:
        # With nothing to look for the alternation below would be empty and
        # could match any line that begins with ':', yielding a bogus entry
        # keyed by ''.
        return {}

    # Get rid of @PLT ending on any call/jmp instruction
    pltRE = r'((?:call|jmp)(?:q|l|w)?\s+[\w.]+)@PLT'
    asmNoPLT = re.sub(pltRE, r'\1', asm, flags=re.MULTILINE)

    # Grab text for any desired symbol up to the first .cfi_endproc or
    # the .size directive for that function.
    symRE = (r'^((%s(%s))\s*:.*?)$\s*(?:\.cfi_endproc|(?:\.size\s+\2,.*?))$'
             % (symPrefix, '|'.join(re.escape(n) for n in names)))

    cfiRE = re.compile(r'\s*\.cfi_')

    funcs = {}
    for m in re.finditer(symRE, asmNoPLT, flags=re.MULTILINE | re.DOTALL):
        name = m.group(3)
        code = m.group(1)

        # Discard CFI clutter from the asm code.
        body = ''.join(line for line in code.splitlines(True)
                       if not cfiRE.match(line))

        # Guarantee a trailing newline to be consistent.
        if body and not body.endswith('\n'):
            body += '\n'
        funcs[name] = body

    return funcs


def emitDisassembly(llvmFile, asmFile, output, useJSON, filterDisasm, cleanDisasm, disasmFile):
    """Write the disassembly requested by the Skip compiler to ``output``.

    llvmFile is iterated line by line looking for ``;;; @disasm_*`` marker
    comments; asmFile is the assembly text handed to extractAsmFunctions.
    output is a file name, or None / '-' for stdout.  useJSON selects JSON
    instead of the human-readable format.  When filterDisasm is set the
    entries are first reduced with filterDisassembly(entries, disasmFile);
    cleanDisasm runs cleanDisassembly over the human-readable text.

    Exits the process with status 1 when no symbol marker was found, or
    when function markers at the end of the file have no symbol marker.
    """
    # Parse stylized comments in the .ll file to figure out what symbols the
    # Skip compiler says we asked to disassemble, and exactly what Skip code
    # those symbols correspond to.
    #
    # The syntax is slightly goofy for ease of parsing using ast.literal_eval.
    # There is a sequence of @disasm_function entries containing the
    # function's name, source file, character start offset and character end
    # offset, followed by the text of the LLVM code, followed by a single
    # @disasm_symbol defining the LLVM function name emitted for the
    # preceding @disasm_function entries.
    #
    # We have a sequence of Skip functions, rather than just one, because
    # sometimes the compiler may collapse identical machine code for multiple
    # functions into a single function, especially when generics are involved.
    #
    # Note that the LLVM symbol may not literally be the same as the assembly
    # code symbol because some platforms prepend an underscore to them.
    #
    # ;;; @disasm_function: ("Buffer<Foo>.mreverse", "prelude/Buffer.sk", 1201, 1500)
    # ;;; @disasm_function: ("Buffer<Bar>.mreverse", "prelude/Buffer.sk", 1201, 1500)
    #
    # <LLVM function body>
    #
    # ;;; @disasm_symbol: "sk.Buffer.mreverse.1"
    #
    # If we somehow get disassembly for a machine-generated function that
    # does not directly correspond to user source, its "@disasm_function"
    # entry will consist of a tuple of only the first value:
    #
    # ;;; @disasm_function: ("<builtin>.someFunc",)
    #
    markerRE = re.compile(r';;; @disasm_(function|symbol): (.*)')
    pendingFunctions = []
    bodyLines = []
    extracted = []
    for line in llvmFile:
        marker = markerRE.match(line)
        if marker is None:
            # Ordinary line: part of the LLVM body only once at least one
            # function marker has been seen for the upcoming symbol.
            if pendingFunctions:
                bodyLines.append(line)
            continue

        kind, literal = marker.groups()
        val = ast.literal_eval(literal)
        if kind == 'function':
            # This describes an upcoming symbol.
            pendingFunctions.append(val)
        else:
            # Map this symbol to all of the Skip functions described
            # on earlier lines.
            extracted.append(
                (val, tuple(pendingFunctions), ''.join(bodyLines)))
            pendingFunctions = []
            bodyLines = []

    if not extracted:
        print('Found no matching functions to disassemble.', file=sys.stderr)
        sys.exit(1)

    if pendingFunctions:
        print('Skip function described but has no symbol.', file=sys.stderr)
        sys.exit(1)

    if filterDisasm:
        extracted = filterDisassembly(extracted, disasmFile)

    asmBySymbol = extractAsmFunctions(
        asmFile, [entry[0] for entry in extracted])

    def matched():
        # Yield only the entries for which assembly code was found.
        for symbol, functions, llvm in extracted:
            code = asmBySymbol.get(symbol)
            if code is not None:
                yield symbol, functions, llvm, code

    def writeHuman(out):
        # One chunk per symbol: a '# name' line per Skip function, then
        # the code.  Chunks are separated by a blank-line gap.
        chunks = []
        for _symbol, functions, _llvm, code in matched():
            header = ''.join('# %s\n' % (f[0],) for f in functions)
            chunks.append(header + code)
        text = '\n\n'.join(chunks)
        if cleanDisasm:
            text = cleanDisassembly(text)
        out.write(text)

    def writeJSON(out):
        fieldNames = ('name', 'file', 'start', 'end')
        entries = []

        # Build up a dictionary of all the disassembly data.
        for symbol, functions, llvm, code in matched():
            entry = collections.OrderedDict()
            entry['functions'] = [
                collections.OrderedDict(zip(fieldNames, f))
                for f in functions
            ]
            entry['symbol'] = symPrefix + symbol
            entry['asm'] = code
            entry['llvm'] = llvm
            entries.append(entry)

        json.dump({'disassembly': entries}, out, indent=4)

    emit = writeJSON if useJSON else writeHuman

    if output in (None, '-'):
        emit(sys.stdout)
    else:
        with open(output, 'w') as out:
            emit(out)


def which(name):
    """Resolve a program name to the path of an existing file, or None.

    A name that already refers to an existing file is returned unchanged.
    Otherwise a name containing a '/' is treated as a path and is not
    searched for; any other name is looked up with
    distutils.spawn.find_executable and returned only if the result is an
    existing file.
    """
    if os.path.isfile(name):
        return name

    found = None if '/' in name else distutils.spawn.find_executable(name)
    if found and os.path.isfile(found):
        return found
    return None


def findProgram(args, name, env_var, cmake_var):
    """Locate an external program and return its path.

    The search order is:
      1. the program named by the environment variable env_var, when
         env_var is given and set to a non-empty value;
      2. outside of a production tree, the value of the cmake_var entry in
         CMakeCache.txt under common.build_dir, when it names an existing
         file;
      3. name, resolved with which().

    args is accepted for uniformity with the callers but is not used.

    Logs an error and exits the process with status 1 when the program
    named by the environment variable cannot be found, or when none of the
    steps above produces a result.
    """
    if env_var:
        # Look for env_var (like 'CC')
        prog = os.environ.get(env_var, None)
        if prog:
            path = which(prog)
            if not path:
                # If they explicitly specified the env var and it doesn't exist
                # then complain.
                logger.error('%s (from environment variable %s) not found',
                             prog, env_var)
                sys.exit(1)
            return path

    if not IN_PROD_TREE:
        # Look in cmake's cache
        cache = os.path.join(common.build_dir, 'CMakeCache.txt')
        if os.path.isfile(cache):
            with open(cache, 'r') as f:
                for line in f:
                    # Cache entries look like KEY:TYPE=VALUE.  Compare the
                    # whole key so that e.g. 'NODE' does not pick up an
                    # unrelated 'NODE_PATH' or 'NODE-ADVANCED' entry.
                    entry, sep, value = line.strip().partition('=')
                    if sep and entry.split(':', 1)[0] == cmake_var:
                        if os.path.isfile(value):
                            return value
                        # Only the first matching entry is considered.
                        break

    # Look for the program in the PATH by name
    path = which(name)
    if path:
        return path
    logger.error('%s not found in PATH or as %s in CMakeCache.txt', name,
                 cmake_var)
    sys.exit(1)



def findClang(args):
    """Return the path of clang++, as located by findProgram.

    The CLANG environment variable and the CLANG_EXECUTABLE CMake cache
    entry are the overrides handed to findProgram.
    """
    return findProgram(
        args,
        name='clang++',
        env_var='CLANG',
        cmake_var='CLANG_EXECUTABLE')


def findNode(args):
    """Return the path of node, as located by findProgram.

    NODE is used both as the environment variable and as the CMake cache
    entry handed to findProgram.
    """
    return findProgram(
        args,
        name='node',
        env_var='NODE',
        cmake_var='NODE')


def main(stack):
    """Run the Skip native compiler driver.

    Parses sys.argv, runs skip_to_llvm over the sources, and then either
    copies the resulting .ll file (--emit-llvm), emits disassembly
    (--disasm-*), or compiles the .ll files into the object file named by
    --output.

    stack is a context-manager stack (the caller passes common.ExitStack())
    on which the temporary .ll files are registered, so that they are
    closed when the caller leaves its ``with`` block.

    Returns 1 if an LLVM compilation subprocess fails and None otherwise.
    Several error paths, and --print-skip-to-llvm, call sys.exit directly.
    """
    parser = argparse.ArgumentParser(
        description='Run the Skip compiler',
        parents=[common.commonArguments(needsBackend=False, backend='native')])
    # Hidden arg used by install to provide the original unzipped path to skip_to_native
    parser.add_argument('--arg0', default=None, help=argparse.SUPPRESS)
    parser.add_argument('--preamble')
    parser.add_argument('-o', '--output', required=True)
    parser.add_argument('--llvm-temp', default=None, help=textwrap.dedent('''\
        The name of the file to use for llvm assembly output.  If not specified
        will use a temporary file and be deleted after compilation is done.  The
        natural extension for this file is .ll
        '''))
    parser.add_argument('--parallel', type=int, help=textwrap.dedent('''\
        How many LLVM processes to run in parallel
        '''))
    parser.add_argument('srcs', metavar='SOURCE', nargs='+')
    parser.add_argument(
        '--disasm-file', action='append',
        help='Emit disassembly for every function in the specified '
        'source file')
    parser.add_argument(
        '--disasm-function', action='append',
        help='Emit disassembly for the specified function or method. '
        'Function names look like "Module.function", and method names look '
        'like "Module.class::method. If not in a module, the "Module." prefix '
        'should be skipped.')
    parser.add_argument(
        '--disasm-all', action='store_true',
        help='Emit disassembly for everything needed by "main" or other '
        'disassembled functions. Unnecessary functions are not compiled so '
        'they do not get disassembled.')
    parser.add_argument(
        '--disasm-annotated', action='store_true',
        help='Emit disassembly for all functions and methods annotated with '
        '@disasm.')
    parser.add_argument(
        '--json', action='store_true',
        help="Emit disassembly in JSON format.")
    parser.add_argument(
        '--m32', action='store_true',
        help="Generate 32 bit code.")
    parser.add_argument(
        '--emit-llvm', action='store_true',
        help="Emit llvm bitcode (.ll) to output file")
    parser.add_argument('--nbe-flags', default='',
        help='Comma separated list of parameters to pass to the NBE compiler')
    parser.add_argument(
        '--filter-disasm', action='store_true',
        help='Filter disassembly to only emit asm needed by functions '
        'defined in the file given by --disasm-file')
    parser.add_argument(
        '--clean-disasm', action='store_true',
        help='Remove .loc, .file, and empty labels from the disassembly')

    if not IN_PROD_TREE:
        parser.add_argument('--via-backend',
                            help='Choose a backend to run.  Should be either '
                            + "'js', 'stage1' or 'native'")

    def isClangArg(arg):
        return (
            arg.startswith(('-f', '-l', '-m', '-std=', '-stdlib=', '-D', '-I', '-O', '-W')) or
            arg in ('-g',))

    # Split the command line: arguments that look like clang flags are set
    # aside in CFLAGS, everything else is parsed by our own parser.
    argv = []
    CFLAGS = []
    for arg in sys.argv[1:]:
        (argv, CFLAGS)[isClangArg(arg)].append(arg)

    args = common.parse_args(parser, argv)

    if args.arg0:
        sys.argv[0] = args.arg0

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    if not args.preamble:
        preamble = 'preamble32.ll' if args.m32 else 'preamble.ll'
        if IN_PROD_TREE:
            args.preamble = common.prefixRelPath('lib/skip', preamble)
        else:
            args.preamble = common.pathRelativeTo(
                common.binary_dir, 'runtime/native/lib', preamble)

    if not args.prelude:
        if IN_PROD_TREE:
            args.prelude = common.prefixRelPath('lib/skip/prelude')
        else:
            args.prelude = common.prefixRelPath('prelude')
    else:
        args.prelude = os.path.normpath(os.path.abspath(args.prelude))

    emitLL = args.emit_llvm

    # Are we disassembling rather than producing a .o file?
    disasm = bool(args.disasm_file or
                  args.disasm_function or
                  args.disasm_all or
                  args.disasm_annotated)

    args.output = os.path.normpath(os.path.abspath(args.output))
    common.ensureDirPathExists(os.path.dirname(args.output))

    if args.llvm_temp:
        llvmFiles = [args.llvm_temp]
    else:
        if disasm or emitLL:
            numFiles = 1
        elif args.parallel is not None:
            numFiles = max(1, args.parallel)
        else:
            parallel = os.environ.get('SKIP_LLVM_PARALLEL', None)
            if parallel is not None:
                numFiles = max(1, int(parallel))
            else:
                # By default, generate up to one .ll+.o file per core.
                import multiprocessing
                numFiles = multiprocessing.cpu_count()

        # The temporary files are owned by `stack`; only their names are
        # kept here.
        llvmFiles = [
            stack.enter_context(
                tempfile.NamedTemporaryFile(prefix='tmp.gen.' + os.path.basename(args.output) + '.', suffix='.ll', delete=(not args.keep_temp))).name
            for _ in range(numFiles)
        ]

    if args.keep_temp:
        logger.debug('KEEPING TEMP FILE: Saving .ll to %s',
                     ', '.join(llvmFiles))

    for f in llvmFiles:
        common.ensureDirPathExists(os.path.dirname(f))

    extraOpts = ['--export-function-as', 'main=skip_main']
    # SKIP_NBE_FLAGS should be a comma-separated list (like gcc's -Wl,args)
    nbe_flags = os.environ.get('SKIP_NBE_FLAGS', '')
    if nbe_flags:
        extraOpts += nbe_flags.split(',')

    if args.nbe_flags:
        extraOpts += args.nbe_flags.split(',')

    skip_flags = os.environ.get('SKIP_FLAGS', '')
    if skip_flags:
        extraOpts += skip_flags.split(',')

    if disasm:
        # Ensure any file we are disassembling is also in the srcs list.
        # Pass through disassembly flags to the native compiler.
        if args.disasm_all:
            extraOpts.append('--disasm-all')
        if args.disasm_annotated:
            extraOpts.append('--disasm-annotated')
        if args.disasm_file:
            missing = set(args.disasm_file) - set(args.srcs)
            if missing:
                print('Requested disassembly of %s not in list of '
                      'files to be compiled.\n' %
                      (', '.join(sorted(map(repr, missing)))), file=sys.stderr)
                sys.exit(1)

            for f in args.disasm_file:
                extraOpts += ['--disasm-file', f]
        if args.disasm_function:
            for f in args.disasm_function:
                extraOpts += ['--disasm-function', f]

    if args.preamble:
        extraOpts += ['--preamble', args.preamble]

    CC = findClang(args)
    NODE = findNode(args)

    LLC = os.environ.get('LLC', 'llc')
    if not os.path.isabs(LLC) or not os.path.isfile(LLC):
        # First look in the same place as clang
        LLC = os.path.join(os.path.dirname(CC), 'llc')
        if not os.path.isfile(LLC):
            LLC = distutils.spawn.find_executable('llc')
    # A missing LLC is fine - we just can't check the .ll file validity
    if LLC:
        OPT = os.path.join(os.path.dirname(LLC), 'opt')
        if not os.path.isfile(OPT):
            # We need to find both opt and llc, else just use clang.
            LLC = None
            OPT = None

    # Run skip_to_llvm on our input
    if IN_PROD_TREE:
        cmd = (common.prefixRelPath('bin/skip_to_llvm'),)
    else:
        cmd = (os.path.join(args.via_backend, 'skip_to_llvm'),)

    for f in llvmFiles:
        cmd += ('--output', f)

    cmd = cmd + tuple(args.srcs) + tuple(extraOpts)

    # print the full command then exit
    if args.print_skip_to_llvm:
        cmd_abs_path = os.path.abspath(os.path.join(common.source_dir, '..', cmd[0]))
        print(cmd_abs_path + '\n' + ' '.join(cmd[1:]))
        # sys.exit rather than the interactive-only `exit` builtin, which is
        # missing when the interpreter runs without the site module.
        sys.exit(0)

    with common.PerfTimer('skip_to_llvm.runtime'):
        common.callHelper(cmd, env={'NODE': NODE})

    if emitLL:
        with common.PerfTimer('emitLL.runtime'):
            shutil.copy(llvmFiles[0], args.output)
    elif disasm:
        # Run clang to parse its .S output.
        cmd = (CC, '-fPIC', llvmFiles[0], '-Wno-unused-command-line-argument')
        cmd += ('-S', '-o', '-') + tuple(CFLAGS)
        with common.PerfTimer('disasm.runtime'):
            asm = subprocess.check_output(cmd)
            with open(llvmFiles[0], 'r') as f:
                emitDisassembly(f, asm, args.output, args.json, args.filter_disasm, args.clean_disasm, args.disasm_file)
    else:
        # Generate our .o from our .ll

        # skip_to_llvm may have decided it didn't need all the files we offered.
        usedLLVMFiles = [x for x in llvmFiles if os.path.getsize(x)]

        optFlag = os.environ.get('SKIP_LLVM_OPT', None)
        if optFlag is None:
            # Find the last user-specified optimization flag.
            userFlag = '-O0'
            nonOptFlags = []
            for c in CFLAGS:
                if c.startswith('-O'):
                    userFlag = c
                else:
                    nonOptFlags.append(c)
            CFLAGS = nonOptFlags

            # We override the CFLAGS optimization level because -O1 seems to
            # perform nearly as well as higher levels for skip_to_llvm output,
            # and compiles much more quickly.
            optFlag = '-O0' if userFlag == '-O0' else '-O1'
        else:
            choices = ('-O', '-O0', '-O1', '-O2', '-O3')
            if optFlag not in choices:
                print('Unhandled value for SKIP_LLVM_OPT; choices '
                      'are %s\n' % (', '.join(choices),), file=sys.stderr)
                sys.exit(1)
            if optFlag == '-O':
                optFlag = '-O1'

        # Launch compilation subprocesses
        subprocesses = []
        objFiles = []
        # Intermediate .o files created here.  Nothing else owns them, so
        # they are removed in the `finally` below unless --keep-temp is set.
        tempObjFiles = []

        try:
            for llvmFile in usedLLVMFiles:
                base = os.path.basename(os.path.splitext(llvmFile)[0])

                if len(usedLLVMFiles) == 1:
                    objFile = args.output
                else:
                    # Reserve a unique name; the compiler overwrites the
                    # (empty) file.  delete=False keeps the name reserved
                    # after the handle is closed.
                    tmpObj = tempfile.NamedTemporaryFile(
                        prefix='tmp.obj.' + base + '.',
                        suffix='.o',
                        delete=False)
                    tmpObj.close()
                    objFile = tmpObj.name
                    tempObjFiles.append(objFile)

                objFiles.append(objFile)

                if LLC:
                    optCommand = [OPT, optFlag, llvmFile]
                    logger.debug('Running: ' +
                                 ' '.join(map(pipes.quote, optCommand)))
                    optProcess = subprocess.Popen(optCommand,
                                                  stdout=subprocess.PIPE)
                    subprocesses.append(optProcess)

                    llcCommand = (
                        LLC,
                        optFlag.replace('-O', '-O='),
                        '-relocation-model=pic',
                        '-filetype=obj',
                        '-data-sections',
                        '-function-sections',
                        '-disable-fp-elim',
                        '-o=' + objFile
                    )
                    logger.debug('Running: ' +
                                 ' '.join(map(pipes.quote, llcCommand)))
                    llcProcess = subprocess.Popen(llcCommand, env=os.environ,
                                                  stdin=optProcess.stdout)
                    # Drop our copy of the pipe's read end so that opt gets
                    # SIGPIPE, instead of blocking forever, if llc exits
                    # before consuming all of its input.
                    optProcess.stdout.close()
                    subprocesses.append(llcProcess)
                else:
                    cmd = (
                        CC,
                        '-fno-omit-frame-pointer',
                        '-mno-omit-leaf-frame-pointer',
                        '-fPIC',
                        '-Wno-unused-command-line-argument',
                        '-c',
                        '-g',
                        '-o', objFile,
                        llvmFile,
                    ) + tuple(CFLAGS) + (optFlag,)
                    logger.debug('Running: ' + ' '.join(map(pipes.quote, cmd)))
                    subprocesses.append(subprocess.Popen(cmd, env=os.environ))

            # Wait for all subprocesses to finish, even if some fail, so
            # we can clean up temp files properly.
            failed = False
            for p in subprocesses:
                if p.wait():
                    failed = True

            if failed:
                print('Compilation failed', file=sys.stderr)
                return 1

            if len(objFiles) == 1:
                if objFiles[0] != args.output:
                    os.rename(objFiles[0], args.output)
            else:
                # Combine the per-file objects into the single output.
                cmd = ('ld', '-r', '-o', args.output,) + tuple(objFiles)
                logger.debug('Running: ' + ' '.join(map(pipes.quote, cmd)))
                subprocess.check_call(cmd, env=os.environ)
        finally:
            if not args.keep_temp:
                for objFile in tempObjFiles:
                    try:
                        os.remove(objFile)
                    except OSError:
                        # Best-effort cleanup: a failed compiler may already
                        # have removed its own output file.
                        pass

if __name__ == '__main__':
    # main registers its temporary .ll files on the stack.  The process
    # exits with main's return value only after the `with` block has been
    # left, i.e. after the stack has released them.
    # NOTE(review): common.ExitStack is presumably equivalent to
    # contextlib.ExitStack -- confirm against common.
    with common.ExitStack() as stack:
        rc = main(stack)
    sys.exit(rc)
